ML Final Project Yanchen Dong¶

Step 1 Problem Statement¶

A Brain tumor is considered as one of the aggressive diseases, among children and adults. Application of automated classification techniques using Machine Learning(ML) and Artificial Intelligence(AI)has consistently shown higher accuracy than manual classification. Hence, proposing a system performing detection and classification by using Deep Learning Algorithms using ConvolutionNeural Network (CNN), Artificial Neural Network (ANN), and TransferLearning (TL) would be helpful to doctors all around the world.

The gold here is to identify tumor type among 'glioma_tumor','no_tumor','meningioma_tumor','pituitary_tumor'.

Step 2 Assumptions/Hypotheses about data and model¶

To Detect and Classify Brain Tumor using, CNN and TL; as an asset of Deep Learning and to examine the tumor position(segmentation).

Tensorflow CNN-based Brain Tumor Detection will be used.

In [1]:
import numpy as np
import os
import keras
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.express as px
import matplotlib.colors

import seaborn as sns
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.utils import shuffle 
from tensorflow.keras.utils import to_categorical

### Creating the CNN Model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.layers import Input, Dense, InputLayer, Flatten, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from keras.models import Model, Sequential
from tensorflow.keras.models import Sequential, load_model
from keras import metrics
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score


# Building Model
from keras.utils import plot_model
from tensorflow.keras import models

# Training Model
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
C:\Users\yanch\anaconda3\Lib\site-packages\paramiko\transport.py:219: CryptographyDeprecationWarning: Blowfish has been deprecated
  "class": algorithms.Blowfish,
In [2]:
colors_dark = ["#1F1F1F", "#313131", '#636363', '#AEAEAE', '#DADADA']
colors_red = ["#331313", "#582626", '#9E1717', '#D35151', '#E9B4B4']
colors_green = ['#01411C','#4B6F44','#4F7942','#74C365','#D0F0C0']

Step 3: EDA¶

Step 4 Feature Engineering & Transformations¶

In [2]:
base_dir = 'C:\\Users\\yanch\\Desktop\\UC\\Classes\\2024 Spring\\ADSP 31009 Machine Learning and Predictive Analytics\\Final Project'
train_dir = os.path.join(base_dir, 'Training')
test_dir = os.path.join(base_dir, 'Testing')
labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
In [3]:
from skimage.transform import resize
X_train = [] #Training Dataset
Y_train = [] #Training Labels

image_size=224

for label in labels:
    path = os.path.join(train_dir, label)
    class_num = labels.index(label)
    for img in os.listdir(path):
        img_array = plt.imread(os.path.join(path, img))
        img_resized = resize(img_array, (image_size, image_size, 3))
        X_train.append(img_resized)
        Y_train.append(class_num)

for label in labels:
    path = os.path.join(test_dir, label)
    class_num = labels.index(label)
    for img in os.listdir(path):
        img_array = plt.imread(os.path.join(path, img))
        img_resized = resize(img_array, (image_size, image_size, 3))
        X_train.append(img_resized)
        Y_train.append(class_num)
        
X_train = np.array(X_train)
Y_train = np.array(Y_train)
        
In [4]:
# Data generators
train_datagen = ImageDataGenerator(rescale=1/255,
                                   rotation_range=90,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True,
                                   vertical_flip=True,
                                   validation_split=0.2)

valid_datagen = ImageDataGenerator(rescale=1/255, validation_split=0.2)

train_generator=train_datagen.flow_from_directory(train_dir,
                                                  target_size=(224,224), color_mode='rgb', shuffle=True,
                                                  subset='training', batch_size=32, class_mode='categorical')

val_generator = valid_datagen.flow_from_directory(train_dir,
                                                  target_size=(224,224), color_mode='rgb', shuffle=True,
                                                  subset='validation',batch_size=32,class_mode='categorical')
Found 2297 images belonging to 4 classes.
Found 573 images belonging to 4 classes.
In [5]:
X_train.shape
Out[5]:
(3264, 224, 224, 3)
In [6]:
# Shuffling data
X_train, Y_train = shuffle(X_train, Y_train, random_state=42)
In [7]:
#After shuffling sample size remains same
X_train.shape
Out[7]:
(3264, 224, 224, 3)
3.1 Distribution of categories¶
In [7]:
# This method uses the classes array, which directly indicates the class index for each image
(unique, counts) = np.unique(train_generator.classes, return_counts=True)
class_counts = dict(zip(unique, counts))

# Mapping index to class names
class_names = {v: k for k, v in train_generator.class_indices.items()}
class_counts_named = {class_names[k]: v for k, v in class_counts.items()}

# Plotting
plt.figure(figsize=(10, 5))
plt.bar(class_counts_named.keys(), class_counts_named.values())
plt.title('Distribution of Classes in Training Data')
plt.xlabel('Class')
plt.ylabel('Number of Images')
plt.xticks(rotation=45)

plt.show()
3.2 Pixel Intensity Distribution¶

Analyzing the distribution of pixel intensities can help in understanding the general characteristics of the images, like contrast and brightness, and might suggest necessary preprocessing steps like histogram equalization.

In [9]:
fig, ax = plt.subplots()
for img in X_train[:5]:  # Use the same images from the first batch
    sns.histplot(img.ravel(), label='Pixel Intensity', ax=ax, kde=True)
ax.set_title('Pixel Intensity Distribution')
ax.legend()
plt.show()
3.3 Sample images after transformation¶
In [11]:
#plotting the images
plt.figure(figsize=(20,20))
for i in range(16):
    plt.subplot(4,4,i+1)
    plt.imshow(X_train[i])
    plt.title(labels[Y_train[i]], fontsize=16, fontweight='bold')
    plt.axis("off")
plt.show()
In [10]:
# Split the data into training and testing and validation
X_train, X_test, Y_train, Y_test = train_test_split(X_train, Y_train, test_size=0.2, random_state=42)
X_train, X_valid, Y_train, Y_valid = train_test_split(X_train, Y_train, test_size=0.1, random_state=42)
In [11]:
print(X_train.shape)
print(X_valid.shape)
print(X_test.shape)
print(Y_train.shape)
print(Y_test.shape)
print(Y_valid.shape)
(2349, 224, 224, 3)
(262, 224, 224, 3)
(653, 224, 224, 3)
(2349,)
(653,)
(262,)
3.4 Distribution of categories in train¶
In [82]:
# Count the number of images in each class
class_counts = np.bincount(Y_train)
class_names = ['glioma', 'meningioma', 'no tumor', 'pituitary']

# Create a DataFrame with class names and counts
train_df = pd.DataFrame({'Class': class_names, 'Count': class_counts})

# Create a bar chart using matplotlib
fig, ax = plt.subplots()

# Plot the bar chart
ax.barh(train_df['Class'], train_df['Count'])

# Add title and labels
ax.set_title('Number of Images in Each Class of the Train Data')
ax.set_xlabel('Count')
ax.set_ylabel('Class')

# Display the plot
plt.show()
In [33]:
# convert string to categorical
from keras.utils import to_categorical
y_train_new = []
y_valid_new = []
y_test_new = []

for i in range(len(Y_train)):
    y_train_new.append(to_categorical(Y_train[i], num_classes=4))

for i in range(len(Y_valid)):
    y_valid_new.append(to_categorical(Y_valid[i], num_classes=4))

for i in range(len(Y_test)):
    y_test_new.append(to_categorical(Y_test[i], num_classes=4))

y_train_new = np.array(y_train_new)
y_valid_new = np.array(y_valid_new)
y_test_new = np.array(y_test_new)
In [35]:
y_train_new.shape
Out[35]:
(2349, 4)
In [37]:
y_test_new.shape
Out[37]:
(653, 4)

Step 5 Proposed Model (No Regularization)¶

Training Loss (Blue Line):

The training loss starts high and decreases sharply, flattening out around epoch 10. This indicates that the model is effectively learning from the training data and minimizing the loss.

Validation Loss (Orange Line):

The validation loss also starts high, decreasing rapidly at first and then fluctuating but stabilizing after epoch 10. The validation loss remains consistently low, indicating good generalization to the validation data despite some fluctuations.

The model shows strong overall performance with a high accuracy of 0.92 and consistently high precision, recall, and f1-scores across all classes.

Both training and validation losses decrease steadily and stabilize, indicating that the model is learning effectively and generalizing well to the validation set.

The fluctuations in validation loss suggest some variability, but the overall trend remains low.

In [38]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import InputLayer, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Model architecture without regularization
model = Sequential()
model.add(InputLayer(input_shape=(image_size, image_size, 3)))

model.add(Conv2D(16, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(256, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(4, activation='softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Summary of the model
model.summary()
C:\Users\yanch\anaconda3\Lib\site-packages\keras\src\layers\core\input_layer.py:25: UserWarning: Argument `input_shape` is deprecated. Use `shape` instead.
  warnings.warn(
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_10 (Conv2D)              │ (None, 220, 220, 16)   │         1,216 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_10 (MaxPooling2D) │ (None, 110, 110, 16)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_11 (Conv2D)              │ (None, 108, 108, 32)   │         4,640 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_11 (MaxPooling2D) │ (None, 54, 54, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_12 (Conv2D)              │ (None, 52, 52, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_12 (MaxPooling2D) │ (None, 26, 26, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_13 (Conv2D)              │ (None, 24, 24, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_13 (MaxPooling2D) │ (None, 12, 12, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_14 (Conv2D)              │ (None, 10, 10, 256)    │       295,168 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_14 (MaxPooling2D) │ (None, 5, 5, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_2 (Flatten)             │ (None, 6400)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_9 (Dense)                 │ (None, 512)            │     3,277,312 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_10 (Dense)                │ (None, 4)              │         2,052 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 3,672,740 (14.01 MB)
 Trainable params: 3,672,740 (14.01 MB)
 Non-trainable params: 0 (0.00 B)
In [15]:
history = model.fit(X_train, y_train_new, 
                    batch_size=64, 
                    epochs=35, 
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/35
 37/100 ━━━━━━━━━━━━━━━━━━━━ 48s 773ms/step - accuracy: 0.3965 - loss: 1.2695
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
  self.gen.throw(typ, value, traceback)
100/100 ━━━━━━━━━━━━━━━━━━━━ 35s 299ms/step - accuracy: 0.4503 - loss: 1.1917 - val_accuracy: 0.5992 - val_loss: 0.9012
Epoch 2/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 238ms/step - accuracy: 0.6621 - loss: 0.8256 - val_accuracy: 0.7099 - val_loss: 0.6280
Epoch 3/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 25s 243ms/step - accuracy: 0.7634 - loss: 0.6076 - val_accuracy: 0.6908 - val_loss: 0.7620
Epoch 4/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 218ms/step - accuracy: 0.7917 - loss: 0.5103 - val_accuracy: 0.8168 - val_loss: 0.4132
Epoch 5/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.8925 - loss: 0.2956 - val_accuracy: 0.7824 - val_loss: 0.5028
Epoch 6/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 222ms/step - accuracy: 0.9123 - loss: 0.2329 - val_accuracy: 0.8092 - val_loss: 0.4947
Epoch 7/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 217ms/step - accuracy: 0.9361 - loss: 0.1684 - val_accuracy: 0.8893 - val_loss: 0.4255
Epoch 8/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9571 - loss: 0.1243 - val_accuracy: 0.8626 - val_loss: 0.4455
Epoch 9/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 221ms/step - accuracy: 0.9795 - loss: 0.0562 - val_accuracy: 0.8931 - val_loss: 0.4505
Epoch 10/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9851 - loss: 0.0488 - val_accuracy: 0.8740 - val_loss: 0.5628
Epoch 11/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9678 - loss: 0.0820 - val_accuracy: 0.8969 - val_loss: 0.4931
Epoch 12/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9799 - loss: 0.0519 - val_accuracy: 0.8817 - val_loss: 0.5548
Epoch 13/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9846 - loss: 0.0475 - val_accuracy: 0.8931 - val_loss: 0.6534
Epoch 14/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 226ms/step - accuracy: 0.9886 - loss: 0.0339 - val_accuracy: 0.8855 - val_loss: 0.6284
Epoch 15/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9918 - loss: 0.0278 - val_accuracy: 0.9084 - val_loss: 0.5133
Epoch 16/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 212ms/step - accuracy: 0.9927 - loss: 0.0279 - val_accuracy: 0.9084 - val_loss: 0.5029
Epoch 17/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 234ms/step - accuracy: 0.9956 - loss: 0.0128 - val_accuracy: 0.9275 - val_loss: 0.5578
Epoch 18/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9986 - loss: 0.0081 - val_accuracy: 0.8969 - val_loss: 0.7537
Epoch 19/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 225ms/step - accuracy: 0.9966 - loss: 0.0129 - val_accuracy: 0.9008 - val_loss: 0.5796
Epoch 20/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9959 - loss: 0.0110 - val_accuracy: 0.9046 - val_loss: 0.6764
Epoch 21/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 236ms/step - accuracy: 0.9989 - loss: 0.0064 - val_accuracy: 0.9084 - val_loss: 0.6961
Epoch 22/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9993 - loss: 0.0053 - val_accuracy: 0.9160 - val_loss: 0.6376
Epoch 23/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9992 - loss: 0.0044 - val_accuracy: 0.9160 - val_loss: 0.5882
Epoch 24/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 229ms/step - accuracy: 0.9994 - loss: 0.0037 - val_accuracy: 0.9198 - val_loss: 0.5992
Epoch 25/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9987 - loss: 0.0043 - val_accuracy: 0.9198 - val_loss: 0.5866
Epoch 26/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 223ms/step - accuracy: 0.9978 - loss: 0.0045 - val_accuracy: 0.9237 - val_loss: 0.5753
Epoch 27/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9991 - loss: 0.0075 - val_accuracy: 0.9198 - val_loss: 0.5439
Epoch 28/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 221ms/step - accuracy: 0.9987 - loss: 0.0040 - val_accuracy: 0.9160 - val_loss: 0.5814
Epoch 29/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 222ms/step - accuracy: 0.9981 - loss: 0.0045 - val_accuracy: 0.9198 - val_loss: 0.5544
Epoch 30/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9993 - loss: 0.0044 - val_accuracy: 0.9160 - val_loss: 0.5499
Epoch 31/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9992 - loss: 0.0028 - val_accuracy: 0.9160 - val_loss: 0.5503
Epoch 32/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9992 - loss: 0.0031 - val_accuracy: 0.9160 - val_loss: 0.5302
Epoch 33/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9989 - loss: 0.0029 - val_accuracy: 0.9160 - val_loss: 0.5438
Epoch 34/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 225ms/step - accuracy: 0.9989 - loss: 0.0026 - val_accuracy: 0.9160 - val_loss: 0.5614
Epoch 35/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9987 - loss: 0.0036 - val_accuracy: 0.9160 - val_loss: 0.5497
In [16]:
# Save the model
# this is baseline model with rotation range = 20
model.save('cnn_model_1.keras')
In [17]:
import matplotlib.pyplot as plt

# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
In [18]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 80ms/step
Val Accuracy = 0.9160
In [19]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 102ms/step
Test Accuracy = 0.9250
In [20]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.91      0.90      0.91       198
           1       0.89      0.91      0.90       183
           2       0.95      0.88      0.92       104
           3       0.96      0.99      0.98       168

    accuracy                           0.92       653
   macro avg       0.93      0.92      0.93       653
weighted avg       0.93      0.92      0.92       653

Step 6 Proposed Model with regularization¶

  1. Input Layer:

It initializes the model to accept input images of size (image_size, image_size, 3), which corresponds to image height, image width, and 3 color channels (RGB).

  1. First Convolutional Block:

Conv2D with 16 filters: Applies a 5x5 convolution kernel to extract features such as edges and textures. The use of 16 filters means it will output 16 different feature maps.

BatchNormalization: Normalizes the activations from the previous layer, which helps in accelerating the training process and stabilizing the learning by normalizing the input layer by re-centering and re-scaling.

MaxPooling2D: Reduces the spatial dimensions (height and width) of the input volume to the next layer by taking the maximum value over a 2x2 pooling window. This helps in reducing the computational cost and overfitting by providing an abstracted form of the representation.

Dropout (0.2): Randomly sets the outgoing edges of 20% of the neurons to zero during training, to prevent overfitting.

  1. Subsequent Convolutional Blocks:

These blocks increase in the number of filters (32, 64, 128, 256). Increasing the number of filters allows the network to capture more complex patterns like textures and shapes.

Each block follows a similar structure: a convolution layer, batch normalization, max pooling, and dropout. This repeated structure helps the network in learning hierarchically more complex features at each level.

Kernel sizes are generally smaller (3x3) in subsequent layers, which is common as deeper layers capture higher-level abstract features where finer granularity is less important.

  1. Flattening:

The output of the final convolutional layer is flattened (converted from a matrix to a vector), so it can be fed into the dense layers.

  1. Dense Layers:

Dense Layer with 512 neurons: This layer is fully connected and uses ReLU activation. It serves as a classifier on the features formed by the convolutions and pooling layers. Dropout (0.2): Again used here to reduce overfitting.

  1. Output Layer:

Dense layer with 4 neurons: This implies the model is intended for a classification task with 4 classes. The softmax activation function is used to output a probability distribution over the 4 classes.

  1. Compilation:

The model uses the Adam optimizer, a popular choice for deep learning tasks as it combines the best properties of the AdaGrad and RMSProp algorithms to optimize its weights.

The loss function is categorical_crossentropy, suitable for multi-class classification problems.

The metric used to evaluate the model is accuracy.

In [59]:
#simple CNN per with augment
model = Sequential()
model.add(InputLayer(input_shape=(image_size, image_size,3)))

model.add(Conv2D(16, kernel_size=(5, 5), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(64, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Conv2D(128, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Conv2D(256, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(4, activation='softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
C:\Users\yanch\anaconda3\Lib\site-packages\keras\src\layers\core\input_layer.py:25: UserWarning:

Argument `input_shape` is deprecated. Use `shape` instead.

Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_10 (Conv2D)              │ (None, 220, 220, 16)   │         1,216 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_4           │ (None, 220, 220, 16)   │            64 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_10 (MaxPooling2D) │ (None, 110, 110, 16)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_10 (Dropout)            │ (None, 110, 110, 16)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_11 (Conv2D)              │ (None, 108, 108, 32)   │         4,640 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_5           │ (None, 108, 108, 32)   │           128 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_11 (MaxPooling2D) │ (None, 54, 54, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_12 (Conv2D)              │ (None, 52, 52, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_12 (MaxPooling2D) │ (None, 26, 26, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_11 (Dropout)            │ (None, 26, 26, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_13 (Conv2D)              │ (None, 24, 24, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_13 (MaxPooling2D) │ (None, 12, 12, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_12 (Dropout)            │ (None, 12, 12, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_14 (Conv2D)              │ (None, 10, 10, 256)    │       295,168 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_14 (MaxPooling2D) │ (None, 5, 5, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_13 (Dropout)            │ (None, 5, 5, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_2 (Flatten)             │ (None, 6400)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_4 (Dense)                 │ (None, 512)            │     3,277,312 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_14 (Dropout)            │ (None, 512)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_5 (Dense)                 │ (None, 4)              │         2,052 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 3,672,932 (14.01 MB)
 Trainable params: 3,672,836 (14.01 MB)
 Non-trainable params: 96 (384.00 B)
6.1 epochs=10, steps_per_epoch=5¶

The disparity between training and validation/test accuracy along with the low test scores suggests the model may be underfitting, as it does not perform well on any of the datasets.

In [38]:
history = model.fit(X_train, y_train_new, 
                    batch_size=64, 
                    epochs=10, 
                    steps_per_epoch=5,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9951 - loss: 0.0164 - val_accuracy: 0.9389 - val_loss: 0.5294
Epoch 2/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9983 - loss: 0.0069 - val_accuracy: 0.9313 - val_loss: 0.5413
Epoch 3/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9803 - loss: 0.0544 - val_accuracy: 0.9351 - val_loss: 0.5765
Epoch 4/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9889 - loss: 0.0235 - val_accuracy: 0.9389 - val_loss: 0.6128
Epoch 5/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9902 - loss: 0.0434 - val_accuracy: 0.9389 - val_loss: 0.6036
Epoch 6/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9746 - loss: 0.1208 - val_accuracy: 0.9351 - val_loss: 0.4841
Epoch 7/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9867 - loss: 0.0426 - val_accuracy: 0.9275 - val_loss: 0.5177
Epoch 8/10
2/5 ━━━━━━━━━━━━━━━━━━━━ 2s 700ms/step - accuracy: 0.9752 - loss: 0.0742
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning:

Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.

5/5 ━━━━━━━━━━━━━━━━━━━━ 3s 405ms/step - accuracy: 0.9791 - loss: 0.0702 - val_accuracy: 0.9237 - val_loss: 0.4940
Epoch 9/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9889 - loss: 0.0358 - val_accuracy: 0.9084 - val_loss: 0.5404
Epoch 10/10
5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9845 - loss: 0.0548 - val_accuracy: 0.9198 - val_loss: 0.5243
In [24]:
# Save the model
# this is baseline model with rotation range = 20
model.save('new_cnn_model_1.keras')
In [39]:
import matplotlib.pyplot as plt

# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
In [29]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 124ms/step
Val Accuracy = 0.2672
In [30]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 3s 132ms/step
Test Accuracy = 0.2450
In [31]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00       219
           1       0.20      0.01      0.01       187
           2       0.17      0.80      0.29        87
           3       0.36      0.56      0.44       160

    accuracy                           0.25       653
   macro avg       0.18      0.34      0.18       653
weighted avg       0.17      0.25      0.15       653

C:\Users\yanch\anaconda3\Lib\site-packages\sklearn\metrics\_classification.py:1469: UndefinedMetricWarning:

Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

C:\Users\yanch\anaconda3\Lib\site-packages\sklearn\metrics\_classification.py:1469: UndefinedMetricWarning:

Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

C:\Users\yanch\anaconda3\Lib\site-packages\sklearn\metrics\_classification.py:1469: UndefinedMetricWarning:

Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

6.2 epochs=50, steps_per_epoch=50¶

Model Performance: The model demonstrates excellent learning capability and generalizes well to unseen data. The balance between training and validation performance suggests that the model configurations, including architecture adjustments, regularization techniques, or hyperparameters, are well-tuned.

Stability and Overfitting: The relatively smooth and convergent training and validation loss curves indicate that the model is stable and not overfitting. This is corroborated by the close tracking of validation loss to training loss.

The training and validation loss curves show a desirable behavior. Training loss steadily decreases, indicating good learning progress. Validation loss decreases alongside and remains close to the training loss, which is a good sign of the model not overfitting.

In [32]:
history = model.fit(X_train, y_train_new, 
                    batch_size=64, 
                    epochs=50, 
                    steps_per_epoch=50,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/50
37/50 ━━━━━━━━━━━━━━━━━━━━ 12s 981ms/step - accuracy: 0.5946 - loss: 0.9527
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning:

Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.

50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 750ms/step - accuracy: 0.6032 - loss: 0.9389 - val_accuracy: 0.4008 - val_loss: 1.2256
Epoch 2/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 738ms/step - accuracy: 0.7068 - loss: 0.7582 - val_accuracy: 0.6145 - val_loss: 1.0696
Epoch 3/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 730ms/step - accuracy: 0.7281 - loss: 0.6811 - val_accuracy: 0.5496 - val_loss: 1.0466
Epoch 4/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 731ms/step - accuracy: 0.7868 - loss: 0.5578 - val_accuracy: 0.6374 - val_loss: 0.9908
Epoch 5/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 743ms/step - accuracy: 0.8076 - loss: 0.4926 - val_accuracy: 0.6718 - val_loss: 0.8159
Epoch 6/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 738ms/step - accuracy: 0.8473 - loss: 0.4301 - val_accuracy: 0.6908 - val_loss: 0.8416
Epoch 7/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 732ms/step - accuracy: 0.8699 - loss: 0.3598 - val_accuracy: 0.7443 - val_loss: 0.6991
Epoch 8/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 487s 10s/step - accuracy: 0.8710 - loss: 0.3516 - val_accuracy: 0.7595 - val_loss: 0.6429
Epoch 9/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 724ms/step - accuracy: 0.8879 - loss: 0.2956 - val_accuracy: 0.8130 - val_loss: 0.5012
Epoch 10/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 726ms/step - accuracy: 0.9031 - loss: 0.2607 - val_accuracy: 0.8588 - val_loss: 0.4286
Epoch 11/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 746ms/step - accuracy: 0.9337 - loss: 0.2054 - val_accuracy: 0.8702 - val_loss: 0.4034
Epoch 12/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 735ms/step - accuracy: 0.9091 - loss: 0.2255 - val_accuracy: 0.9008 - val_loss: 0.3470
Epoch 13/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 780ms/step - accuracy: 0.9400 - loss: 0.1757 - val_accuracy: 0.8817 - val_loss: 0.4146
Epoch 14/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 753ms/step - accuracy: 0.9507 - loss: 0.1474 - val_accuracy: 0.8817 - val_loss: 0.4246
Epoch 15/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 754ms/step - accuracy: 0.9528 - loss: 0.1394 - val_accuracy: 0.9084 - val_loss: 0.3712
Epoch 16/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 747ms/step - accuracy: 0.9588 - loss: 0.1137 - val_accuracy: 0.9122 - val_loss: 0.3470
Epoch 17/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 732ms/step - accuracy: 0.9605 - loss: 0.1113 - val_accuracy: 0.9237 - val_loss: 0.3751
Epoch 18/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 734ms/step - accuracy: 0.9551 - loss: 0.1080 - val_accuracy: 0.8931 - val_loss: 0.3842
Epoch 19/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 748ms/step - accuracy: 0.9650 - loss: 0.0972 - val_accuracy: 0.9160 - val_loss: 0.3153
Epoch 20/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 735ms/step - accuracy: 0.9733 - loss: 0.0816 - val_accuracy: 0.9237 - val_loss: 0.4023
Epoch 21/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 764ms/step - accuracy: 0.9799 - loss: 0.0745 - val_accuracy: 0.9313 - val_loss: 0.3908
Epoch 22/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 817ms/step - accuracy: 0.9625 - loss: 0.0931 - val_accuracy: 0.9275 - val_loss: 0.3757
Epoch 23/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 806ms/step - accuracy: 0.9834 - loss: 0.0574 - val_accuracy: 0.9313 - val_loss: 0.3387
Epoch 24/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 793ms/step - accuracy: 0.9730 - loss: 0.0750 - val_accuracy: 0.9313 - val_loss: 0.2972
Epoch 25/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 798ms/step - accuracy: 0.9832 - loss: 0.0579 - val_accuracy: 0.9160 - val_loss: 0.4324
Epoch 26/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 764ms/step - accuracy: 0.9682 - loss: 0.1009 - val_accuracy: 0.9237 - val_loss: 0.4243
Epoch 27/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 756ms/step - accuracy: 0.9774 - loss: 0.0579 - val_accuracy: 0.9351 - val_loss: 0.3809
Epoch 28/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 768ms/step - accuracy: 0.9821 - loss: 0.0566 - val_accuracy: 0.9427 - val_loss: 0.3873
Epoch 29/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 759ms/step - accuracy: 0.9894 - loss: 0.0359 - val_accuracy: 0.9427 - val_loss: 0.4556
Epoch 30/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 757ms/step - accuracy: 0.9863 - loss: 0.0442 - val_accuracy: 0.9160 - val_loss: 0.4866
Epoch 31/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 760ms/step - accuracy: 0.9858 - loss: 0.0473 - val_accuracy: 0.9313 - val_loss: 0.4852
Epoch 32/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 783ms/step - accuracy: 0.9872 - loss: 0.0445 - val_accuracy: 0.9389 - val_loss: 0.4071
Epoch 33/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 810ms/step - accuracy: 0.9916 - loss: 0.0268 - val_accuracy: 0.9198 - val_loss: 0.5157
Epoch 34/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 761ms/step - accuracy: 0.9872 - loss: 0.0383 - val_accuracy: 0.9275 - val_loss: 0.4700
Epoch 35/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 765ms/step - accuracy: 0.9849 - loss: 0.0430 - val_accuracy: 0.9160 - val_loss: 0.4828
Epoch 36/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 755ms/step - accuracy: 0.9837 - loss: 0.0496 - val_accuracy: 0.9275 - val_loss: 0.5223
Epoch 37/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 756ms/step - accuracy: 0.9831 - loss: 0.0462 - val_accuracy: 0.9389 - val_loss: 0.4216
Epoch 38/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 757ms/step - accuracy: 0.9848 - loss: 0.0453 - val_accuracy: 0.9427 - val_loss: 0.4807
Epoch 39/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 749ms/step - accuracy: 0.9832 - loss: 0.0521 - val_accuracy: 0.9198 - val_loss: 0.6065
Epoch 40/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 771ms/step - accuracy: 0.9840 - loss: 0.0441 - val_accuracy: 0.9389 - val_loss: 0.5364
Epoch 41/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 42s 829ms/step - accuracy: 0.9882 - loss: 0.0404 - val_accuracy: 0.9313 - val_loss: 0.4418
Epoch 42/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 777ms/step - accuracy: 0.9881 - loss: 0.0338 - val_accuracy: 0.9427 - val_loss: 0.4785
Epoch 43/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 760ms/step - accuracy: 0.9925 - loss: 0.0230 - val_accuracy: 0.9580 - val_loss: 0.4233
Epoch 44/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 780ms/step - accuracy: 0.9899 - loss: 0.0393 - val_accuracy: 0.9427 - val_loss: 0.4645
Epoch 45/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 764ms/step - accuracy: 0.9915 - loss: 0.0212 - val_accuracy: 0.9389 - val_loss: 0.5971
Epoch 46/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 761ms/step - accuracy: 0.9898 - loss: 0.0299 - val_accuracy: 0.9504 - val_loss: 0.5630
Epoch 47/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 778ms/step - accuracy: 0.9895 - loss: 0.0299 - val_accuracy: 0.9427 - val_loss: 0.4841
Epoch 48/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 759ms/step - accuracy: 0.9892 - loss: 0.0384 - val_accuracy: 0.9466 - val_loss: 0.4732
Epoch 49/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 786ms/step - accuracy: 0.9900 - loss: 0.0299 - val_accuracy: 0.9466 - val_loss: 0.4952
Epoch 50/50
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 763ms/step - accuracy: 0.9898 - loss: 0.0312 - val_accuracy: 0.9427 - val_loss: 0.5609
In [33]:
# Save the model
# this is baseline model with rotation range = 20
model.save('new_cnn_model_2.keras')
In [34]:
import matplotlib.pyplot as plt

# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
In [35]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 101ms/step
Val Accuracy = 0.9427
In [36]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 110ms/step
Test Accuracy = 0.9173
In [37]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.88      0.94      0.91       219
           1       0.95      0.82      0.88       187
           2       0.92      0.92      0.92        87
           3       0.94      0.99      0.96       160

    accuracy                           0.92       653
   macro avg       0.92      0.92      0.92       653
weighted avg       0.92      0.92      0.92       653

6.3 epochs=50, steps_per_epoch=100¶

Loss Stability Concerns: Despite excellent performance metrics, the variability in the validation loss could still be a concern. It may suggest that the model could start overfitting if trained for more epochs without adjustments.

In [40]:
history = model.fit(X_train, y_train_new, 
                    batch_size=64, 
                    epochs=50, 
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 394ms/step - accuracy: 0.9861 - loss: 0.0356 - val_accuracy: 0.9313 - val_loss: 0.5516
Epoch 2/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 515ms/step - accuracy: 0.9854 - loss: 0.0395 - val_accuracy: 0.9427 - val_loss: 0.4838
Epoch 3/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 479ms/step - accuracy: 0.9863 - loss: 0.0466 - val_accuracy: 0.9389 - val_loss: 0.5167
Epoch 4/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 481ms/step - accuracy: 0.9904 - loss: 0.0240 - val_accuracy: 0.9504 - val_loss: 0.4565
Epoch 5/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 469ms/step - accuracy: 0.9896 - loss: 0.0271 - val_accuracy: 0.9351 - val_loss: 0.5200
Epoch 6/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 515ms/step - accuracy: 0.9911 - loss: 0.0332 - val_accuracy: 0.9427 - val_loss: 0.4462
Epoch 7/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 427ms/step - accuracy: 0.9931 - loss: 0.0218 - val_accuracy: 0.9427 - val_loss: 0.4205
Epoch 8/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 394ms/step - accuracy: 0.9918 - loss: 0.0275 - val_accuracy: 0.9313 - val_loss: 0.4520
Epoch 9/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 448ms/step - accuracy: 0.9854 - loss: 0.0446 - val_accuracy: 0.9237 - val_loss: 0.5280
Epoch 10/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 419ms/step - accuracy: 0.9859 - loss: 0.0413 - val_accuracy: 0.9313 - val_loss: 0.5361
Epoch 11/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 396ms/step - accuracy: 0.9908 - loss: 0.0263 - val_accuracy: 0.9466 - val_loss: 0.5306
Epoch 12/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 44s 432ms/step - accuracy: 0.9872 - loss: 0.0373 - val_accuracy: 0.9389 - val_loss: 0.4706
Epoch 13/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 426ms/step - accuracy: 0.9905 - loss: 0.0328 - val_accuracy: 0.9351 - val_loss: 0.4630
Epoch 14/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 406ms/step - accuracy: 0.9916 - loss: 0.0253 - val_accuracy: 0.9313 - val_loss: 0.5127
Epoch 15/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 414ms/step - accuracy: 0.9857 - loss: 0.0457 - val_accuracy: 0.9160 - val_loss: 0.5916
Epoch 16/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 422ms/step - accuracy: 0.9826 - loss: 0.0566 - val_accuracy: 0.9389 - val_loss: 0.4681
Epoch 17/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 418ms/step - accuracy: 0.9882 - loss: 0.0367 - val_accuracy: 0.9237 - val_loss: 0.4785
Epoch 18/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 407ms/step - accuracy: 0.9899 - loss: 0.0242 - val_accuracy: 0.9046 - val_loss: 0.7800
Epoch 19/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 498ms/step - accuracy: 0.9911 - loss: 0.0288 - val_accuracy: 0.9427 - val_loss: 0.5952
Epoch 20/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 549ms/step - accuracy: 0.9956 - loss: 0.0171 - val_accuracy: 0.9351 - val_loss: 0.6594
Epoch 21/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 522ms/step - accuracy: 0.9913 - loss: 0.0264 - val_accuracy: 0.9427 - val_loss: 0.6845
Epoch 22/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 555ms/step - accuracy: 0.9953 - loss: 0.0183 - val_accuracy: 0.9427 - val_loss: 0.5893
Epoch 23/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 55s 535ms/step - accuracy: 0.9968 - loss: 0.0114 - val_accuracy: 0.9427 - val_loss: 0.6818
Epoch 24/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 512ms/step - accuracy: 0.9953 - loss: 0.0230 - val_accuracy: 0.9389 - val_loss: 0.6214
Epoch 25/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 54s 530ms/step - accuracy: 0.9949 - loss: 0.0188 - val_accuracy: 0.9351 - val_loss: 0.6380
Epoch 26/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 510ms/step - accuracy: 0.9944 - loss: 0.0229 - val_accuracy: 0.9389 - val_loss: 0.7436
Epoch 27/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 517ms/step - accuracy: 0.9942 - loss: 0.0166 - val_accuracy: 0.9389 - val_loss: 0.6644
Epoch 28/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 526ms/step - accuracy: 0.9872 - loss: 0.0349 - val_accuracy: 0.9275 - val_loss: 0.7268
Epoch 29/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 516ms/step - accuracy: 0.9936 - loss: 0.0243 - val_accuracy: 0.9466 - val_loss: 0.6768
Epoch 30/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 515ms/step - accuracy: 0.9888 - loss: 0.0291 - val_accuracy: 0.9313 - val_loss: 0.5553
Epoch 31/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 513ms/step - accuracy: 0.9836 - loss: 0.0569 - val_accuracy: 0.9313 - val_loss: 0.6402
Epoch 32/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 503ms/step - accuracy: 0.9914 - loss: 0.0280 - val_accuracy: 0.9504 - val_loss: 0.5774
Epoch 33/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 399ms/step - accuracy: 0.9852 - loss: 0.0441 - val_accuracy: 0.9389 - val_loss: 0.4621
Epoch 34/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 412ms/step - accuracy: 0.9868 - loss: 0.0392 - val_accuracy: 0.9466 - val_loss: 0.4100
Epoch 35/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 54s 534ms/step - accuracy: 0.9945 - loss: 0.0211 - val_accuracy: 0.9427 - val_loss: 0.4357
Epoch 36/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 60s 598ms/step - accuracy: 0.9942 - loss: 0.0218 - val_accuracy: 0.9427 - val_loss: 0.5096
Epoch 37/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 474ms/step - accuracy: 0.9911 - loss: 0.0231 - val_accuracy: 0.9504 - val_loss: 0.4844
Epoch 38/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 552ms/step - accuracy: 0.9934 - loss: 0.0187 - val_accuracy: 0.9389 - val_loss: 0.5852
Epoch 39/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 412ms/step - accuracy: 0.9952 - loss: 0.0111 - val_accuracy: 0.9466 - val_loss: 0.6731
Epoch 40/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 419ms/step - accuracy: 0.9937 - loss: 0.0232 - val_accuracy: 0.9351 - val_loss: 0.7436
Epoch 41/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 406ms/step - accuracy: 0.9859 - loss: 0.0451 - val_accuracy: 0.9580 - val_loss: 0.4898
Epoch 42/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 518ms/step - accuracy: 0.9888 - loss: 0.0285 - val_accuracy: 0.9351 - val_loss: 0.6509
Epoch 43/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 496ms/step - accuracy: 0.9893 - loss: 0.0317 - val_accuracy: 0.9313 - val_loss: 0.4865
Epoch 44/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 50s 490ms/step - accuracy: 0.9880 - loss: 0.0427 - val_accuracy: 0.9466 - val_loss: 0.5625
Epoch 45/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 506ms/step - accuracy: 0.9873 - loss: 0.0441 - val_accuracy: 0.9466 - val_loss: 0.4856
Epoch 46/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 520ms/step - accuracy: 0.9909 - loss: 0.0250 - val_accuracy: 0.9275 - val_loss: 0.6252
Epoch 47/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 510ms/step - accuracy: 0.9846 - loss: 0.0482 - val_accuracy: 0.9351 - val_loss: 0.5449
Epoch 48/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 50s 491ms/step - accuracy: 0.9914 - loss: 0.0350 - val_accuracy: 0.9313 - val_loss: 0.5463
Epoch 49/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 485ms/step - accuracy: 0.9886 - loss: 0.0433 - val_accuracy: 0.9427 - val_loss: 0.6445
Epoch 50/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 46s 446ms/step - accuracy: 0.9814 - loss: 0.0612 - val_accuracy: 0.9237 - val_loss: 0.5988
In [41]:
# Save the model
model.save('new_cnn_model_3.keras')
In [42]:
import matplotlib.pyplot as plt

# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
In [43]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 99ms/step 
Val Accuracy = 0.9237
In [44]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 109ms/step
Test Accuracy = 0.9449
In [45]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.92      0.94       219
           1       0.95      0.93      0.94       187
           2       0.88      0.94      0.91        87
           3       0.95      1.00      0.97       160

    accuracy                           0.94       653
   macro avg       0.94      0.95      0.94       653
weighted avg       0.95      0.94      0.94       653

6.4 epoch = 35, steps_per_epoch=100¶
In [60]:
history = model.fit(X_train, y_train_new, 
                    batch_size=64, 
                    epochs=35, 
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/35
 37/100 ━━━━━━━━━━━━━━━━━━━━ 1:02 991ms/step - accuracy: 0.3305 - loss: 4.8506
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning:

Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.

100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 375ms/step - accuracy: 0.3932 - loss: 3.3018 - val_accuracy: 0.4466 - val_loss: 1.3470
Epoch 2/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 37s 368ms/step - accuracy: 0.5852 - loss: 0.9946 - val_accuracy: 0.3130 - val_loss: 1.3530
Epoch 3/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 371ms/step - accuracy: 0.6493 - loss: 0.8459 - val_accuracy: 0.3053 - val_loss: 1.4348
Epoch 4/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 369ms/step - accuracy: 0.6894 - loss: 0.7776 - val_accuracy: 0.3244 - val_loss: 1.5073
Epoch 5/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 373ms/step - accuracy: 0.7075 - loss: 0.7224 - val_accuracy: 0.3550 - val_loss: 1.4179
Epoch 6/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 378ms/step - accuracy: 0.7537 - loss: 0.6409 - val_accuracy: 0.3588 - val_loss: 1.4829
Epoch 7/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 368ms/step - accuracy: 0.7665 - loss: 0.5970 - val_accuracy: 0.4504 - val_loss: 1.2290
Epoch 8/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 374ms/step - accuracy: 0.8033 - loss: 0.5161 - val_accuracy: 0.4924 - val_loss: 1.2251
Epoch 9/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 371ms/step - accuracy: 0.8198 - loss: 0.4714 - val_accuracy: 0.5534 - val_loss: 1.1228
Epoch 10/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.8283 - loss: 0.4445 - val_accuracy: 0.7099 - val_loss: 0.7793
Epoch 11/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 444ms/step - accuracy: 0.8432 - loss: 0.4166 - val_accuracy: 0.7519 - val_loss: 0.6910
Epoch 12/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 425ms/step - accuracy: 0.8540 - loss: 0.3607 - val_accuracy: 0.7366 - val_loss: 0.6631
Epoch 13/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 44s 432ms/step - accuracy: 0.8700 - loss: 0.3242 - val_accuracy: 0.7290 - val_loss: 0.7293
Epoch 14/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 410ms/step - accuracy: 0.8770 - loss: 0.3202 - val_accuracy: 0.8282 - val_loss: 0.4950
Epoch 15/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 47s 461ms/step - accuracy: 0.9050 - loss: 0.2347 - val_accuracy: 0.7710 - val_loss: 0.7423
Epoch 16/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 411ms/step - accuracy: 0.9131 - loss: 0.2261 - val_accuracy: 0.8473 - val_loss: 0.4228
Epoch 17/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 474ms/step - accuracy: 0.9374 - loss: 0.1804 - val_accuracy: 0.8855 - val_loss: 0.3896
Epoch 18/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 46s 449ms/step - accuracy: 0.9257 - loss: 0.1934 - val_accuracy: 0.8893 - val_loss: 0.3702
Epoch 19/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 394ms/step - accuracy: 0.9385 - loss: 0.1609 - val_accuracy: 0.9046 - val_loss: 0.3787
Epoch 20/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 409ms/step - accuracy: 0.9522 - loss: 0.1312 - val_accuracy: 0.8969 - val_loss: 0.4051
Epoch 21/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 412ms/step - accuracy: 0.9461 - loss: 0.1377 - val_accuracy: 0.8435 - val_loss: 0.4281
Epoch 22/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 441ms/step - accuracy: 0.9588 - loss: 0.1241 - val_accuracy: 0.8969 - val_loss: 0.3753
Epoch 23/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 391ms/step - accuracy: 0.9568 - loss: 0.1206 - val_accuracy: 0.9122 - val_loss: 0.3589
Epoch 24/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 383ms/step - accuracy: 0.9659 - loss: 0.0923 - val_accuracy: 0.9122 - val_loss: 0.3198
Epoch 25/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 401ms/step - accuracy: 0.9632 - loss: 0.1027 - val_accuracy: 0.8969 - val_loss: 0.3980
Epoch 26/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.9676 - loss: 0.0879 - val_accuracy: 0.8969 - val_loss: 0.4138
Epoch 27/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.9734 - loss: 0.0846 - val_accuracy: 0.8893 - val_loss: 0.4401
Epoch 28/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 381ms/step - accuracy: 0.9737 - loss: 0.0844 - val_accuracy: 0.9198 - val_loss: 0.3196
Epoch 29/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 370ms/step - accuracy: 0.9690 - loss: 0.0896 - val_accuracy: 0.8740 - val_loss: 0.4665
Epoch 30/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 376ms/step - accuracy: 0.9777 - loss: 0.0713 - val_accuracy: 0.9160 - val_loss: 0.3465
Epoch 31/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 378ms/step - accuracy: 0.9777 - loss: 0.0761 - val_accuracy: 0.9084 - val_loss: 0.3995
Epoch 32/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 382ms/step - accuracy: 0.9735 - loss: 0.0870 - val_accuracy: 0.9160 - val_loss: 0.3470
Epoch 33/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 374ms/step - accuracy: 0.9787 - loss: 0.0633 - val_accuracy: 0.9198 - val_loss: 0.3858
Epoch 34/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 377ms/step - accuracy: 0.9821 - loss: 0.0524 - val_accuracy: 0.9237 - val_loss: 0.3363
Epoch 35/35
100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 381ms/step - accuracy: 0.9829 - loss: 0.0523 - val_accuracy: 0.9237 - val_loss: 0.3193
In [61]:
# Save the model
model.save('new_cnn_model_6.keras')
In [62]:
import matplotlib.pyplot as plt

# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
In [63]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 127ms/step
Val Accuracy = 0.9237
In [64]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 3s 125ms/step
Test Accuracy = 0.9280
In [65]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.90      0.93      0.92       219
           1       0.95      0.86      0.90       187
           2       0.87      0.94      0.91        87
           3       0.98      1.00      0.99       160

    accuracy                           0.93       653
   macro avg       0.92      0.93      0.93       653
weighted avg       0.93      0.93      0.93       653

6.5 epochs=100, steps_per_epoch=37 (early stopped at 46)¶
  1. Model Performance Overview

Validation Accuracy: The model achieves a validation accuracy of 92.75%.

Test Accuracy: The test accuracy is even higher at 90.96%.

Precision and Recall: All classes show strong precision (85-96%) and recall (80-100%). This indicates that the model is not only correctly identifying positive cases but is also precise in its predictions, minimizing false positives.

F1-Score: High F1-scores across all classes (89-97%) suggest a balanced performance between precision and recall, which is crucial for reliable classification.

  1. Training and Validation Curves

Accuracy Curve: The training accuracy plateaus close to 100%, while the validation accuracy stabilizes at a high level but with some gap compared to the training, indicating a slight overfitting but still within an acceptable range. Loss Curve: Training loss decreases sharply and flattens, which is ideal. However, the validation loss, despite decreasing, shows more fluctuations, which is typical but should be monitored to ensure it doesn't start to diverge from the training loss significantly.

  1. Callbacks and Adjustments

I've implemented useful callbacks like EarlyStopping and ReduceLROnPlateau, which are beneficial for handling overfitting and optimizing the training process:

EarlyStopping is configured to monitor the training loss, stopping the training if there are no improvements beyond a minimal delta, indicating that continuing training is inefficient.

ReduceLROnPlateau reduces the learning rate when the validation loss stops improving, helping the model to fine-tune adjustments in weights and potentially escape local minima.

The EarlyStopping and ReduceLROnPlateau are both callbacks in Keras that serve as training interventions to improve the training process and prevent overfitting. Each of these has specific roles and is used to monitor different aspects of the model during training. Let’s delve into the goals and functionalities of each:

EarlyStopping Goal: To halt the training process early if there is no significant improvement in a specified metric over a defined number of epochs. This is particularly useful in avoiding overfitting and unnecessarily long training times.

ReduceLROnPlateau Goal: To reduce the learning rate when a metric has stopped improving. This helps the model to fine-tune and potentially escape local minima during training. Lowering the learning rate can allow the model to make smaller changes to the weights and potentially discover better minima.

In [23]:
# Stop training if loss doesn't keep decreasing.
model_es = EarlyStopping(monitor='loss', min_delta=1e-9, patience=12, verbose=True)
model_rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=6, verbose=True)

history = model.fit(X_train, y_train_new, batch_size=64, epochs=100, validation_data=(X_valid, y_valid_new),
                   callbacks=[model_es, model_rlr])
Epoch 1/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3408 - loss: 4.1512 - val_accuracy: 0.1641 - val_loss: 1.3980 - learning_rate: 0.0010
Epoch 2/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 46s 1s/step - accuracy: 0.6022 - loss: 0.9985 - val_accuracy: 0.3397 - val_loss: 1.3390 - learning_rate: 0.0010
Epoch 3/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 40s 1s/step - accuracy: 0.6243 - loss: 0.8712 - val_accuracy: 0.3588 - val_loss: 1.3417 - learning_rate: 0.0010
Epoch 4/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.6738 - loss: 0.7678 - val_accuracy: 0.3626 - val_loss: 1.5721 - learning_rate: 0.0010
Epoch 5/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.7683 - loss: 0.5974 - val_accuracy: 0.3588 - val_loss: 1.3576 - learning_rate: 0.0010
Epoch 6/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.7721 - loss: 0.5671 - val_accuracy: 0.4924 - val_loss: 1.1619 - learning_rate: 0.0010
Epoch 7/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 40s 1s/step - accuracy: 0.8083 - loss: 0.4945 - val_accuracy: 0.6565 - val_loss: 0.8985 - learning_rate: 0.0010
Epoch 8/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.8143 - loss: 0.4453 - val_accuracy: 0.6603 - val_loss: 0.8329 - learning_rate: 0.0010
Epoch 9/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.8614 - loss: 0.3480 - val_accuracy: 0.7214 - val_loss: 0.7044 - learning_rate: 0.0010
Epoch 10/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 40s 1s/step - accuracy: 0.8816 - loss: 0.3137 - val_accuracy: 0.7481 - val_loss: 0.6579 - learning_rate: 0.0010
Epoch 11/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.8578 - loss: 0.3504 - val_accuracy: 0.8092 - val_loss: 0.5405 - learning_rate: 0.0010
Epoch 12/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 40s 1s/step - accuracy: 0.9146 - loss: 0.2361 - val_accuracy: 0.8206 - val_loss: 0.5122 - learning_rate: 0.0010
Epoch 13/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9134 - loss: 0.2439 - val_accuracy: 0.8969 - val_loss: 0.4035 - learning_rate: 0.0010
Epoch 14/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 42s 1s/step - accuracy: 0.9283 - loss: 0.1944 - val_accuracy: 0.8931 - val_loss: 0.3910 - learning_rate: 0.0010
Epoch 15/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 45s 1s/step - accuracy: 0.9340 - loss: 0.2014 - val_accuracy: 0.8931 - val_loss: 0.3655 - learning_rate: 0.0010
Epoch 16/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.9562 - loss: 0.1278 - val_accuracy: 0.8893 - val_loss: 0.4169 - learning_rate: 0.0010
Epoch 17/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.9317 - loss: 0.1774 - val_accuracy: 0.9160 - val_loss: 0.3436 - learning_rate: 0.0010
Epoch 18/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.9531 - loss: 0.1303 - val_accuracy: 0.8893 - val_loss: 0.4071 - learning_rate: 0.0010
Epoch 19/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 42s 1s/step - accuracy: 0.9679 - loss: 0.1018 - val_accuracy: 0.9237 - val_loss: 0.3675 - learning_rate: 0.0010
Epoch 20/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 43s 1s/step - accuracy: 0.9721 - loss: 0.0828 - val_accuracy: 0.9237 - val_loss: 0.3545 - learning_rate: 0.0010
Epoch 21/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9722 - loss: 0.0815 - val_accuracy: 0.8550 - val_loss: 0.4985 - learning_rate: 0.0010
Epoch 22/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 40s 1s/step - accuracy: 0.9559 - loss: 0.1091 - val_accuracy: 0.9008 - val_loss: 0.4358 - learning_rate: 0.0010
Epoch 23/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9783 - loss: 0.0598
Epoch 23: ReduceLROnPlateau reducing learning rate to 0.0003000000142492354.
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9783 - loss: 0.0600 - val_accuracy: 0.8931 - val_loss: 0.4360 - learning_rate: 0.0010
Epoch 24/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9809 - loss: 0.0616 - val_accuracy: 0.9237 - val_loss: 0.4380 - learning_rate: 3.0000e-04
Epoch 25/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9850 - loss: 0.0484 - val_accuracy: 0.9198 - val_loss: 0.4186 - learning_rate: 3.0000e-04
Epoch 26/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9890 - loss: 0.0320 - val_accuracy: 0.9313 - val_loss: 0.4050 - learning_rate: 3.0000e-04
Epoch 27/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9886 - loss: 0.0322 - val_accuracy: 0.9275 - val_loss: 0.3922 - learning_rate: 3.0000e-04
Epoch 28/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9924 - loss: 0.0293 - val_accuracy: 0.9275 - val_loss: 0.4017 - learning_rate: 3.0000e-04
Epoch 29/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9883 - loss: 0.0285
Epoch 29: ReduceLROnPlateau reducing learning rate to 9.000000427477062e-05.
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9883 - loss: 0.0286 - val_accuracy: 0.9237 - val_loss: 0.4063 - learning_rate: 3.0000e-04
Epoch 30/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9908 - loss: 0.0232 - val_accuracy: 0.9275 - val_loss: 0.4082 - learning_rate: 9.0000e-05
Epoch 31/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.9893 - loss: 0.0254 - val_accuracy: 0.9198 - val_loss: 0.4178 - learning_rate: 9.0000e-05
Epoch 32/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9930 - loss: 0.0190 - val_accuracy: 0.9237 - val_loss: 0.4098 - learning_rate: 9.0000e-05
Epoch 33/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.9945 - loss: 0.0169 - val_accuracy: 0.9275 - val_loss: 0.4167 - learning_rate: 9.0000e-05
Epoch 34/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9964 - loss: 0.0160 - val_accuracy: 0.9237 - val_loss: 0.4290 - learning_rate: 9.0000e-05
Epoch 35/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9949 - loss: 0.0200
Epoch 35: ReduceLROnPlateau reducing learning rate to 2.700000040931627e-05.
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9949 - loss: 0.0200 - val_accuracy: 0.9237 - val_loss: 0.4500 - learning_rate: 9.0000e-05
Epoch 36/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 38s 1s/step - accuracy: 0.9931 - loss: 0.0202 - val_accuracy: 0.9198 - val_loss: 0.4375 - learning_rate: 2.7000e-05
Epoch 37/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 39s 1s/step - accuracy: 0.9936 - loss: 0.0185 - val_accuracy: 0.9198 - val_loss: 0.4368 - learning_rate: 2.7000e-05
Epoch 38/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 47s 1s/step - accuracy: 0.9951 - loss: 0.0168 - val_accuracy: 0.9237 - val_loss: 0.4384 - learning_rate: 2.7000e-05
Epoch 39/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 44s 1s/step - accuracy: 0.9930 - loss: 0.0242 - val_accuracy: 0.9237 - val_loss: 0.4341 - learning_rate: 2.7000e-05
Epoch 40/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.9914 - loss: 0.0349 - val_accuracy: 0.9275 - val_loss: 0.4438 - learning_rate: 2.7000e-05
Epoch 41/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9928 - loss: 0.0206
Epoch 41: ReduceLROnPlateau reducing learning rate to 8.100000013655517e-06.
37/37 ━━━━━━━━━━━━━━━━━━━━ 47s 1s/step - accuracy: 0.9928 - loss: 0.0207 - val_accuracy: 0.9275 - val_loss: 0.4437 - learning_rate: 2.7000e-05
Epoch 42/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 46s 1s/step - accuracy: 0.9938 - loss: 0.0207 - val_accuracy: 0.9275 - val_loss: 0.4443 - learning_rate: 8.1000e-06
Epoch 43/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.9900 - loss: 0.0334 - val_accuracy: 0.9275 - val_loss: 0.4419 - learning_rate: 8.1000e-06
Epoch 44/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 46s 1s/step - accuracy: 0.9928 - loss: 0.0202 - val_accuracy: 0.9275 - val_loss: 0.4428 - learning_rate: 8.1000e-06
Epoch 45/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 42s 1s/step - accuracy: 0.9939 - loss: 0.0176 - val_accuracy: 0.9275 - val_loss: 0.4450 - learning_rate: 8.1000e-06
Epoch 46/100
37/37 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.9941 - loss: 0.0169 - val_accuracy: 0.9275 - val_loss: 0.4445 - learning_rate: 8.1000e-06
Epoch 46: early stopping
In [24]:
# Save the model
# this is baseline model with rotation range = 20
model.save('new_cnn_model1.keras')
In [25]:
# Predict the val model
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 125ms/step
Val Accuracy = 0.9275
In [27]:
# Predict the test model
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)

# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 106ms/step
Test Accuracy = 0.9096
In [28]:
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.93      0.89       198
           1       0.91      0.80      0.85       183
           2       0.96      0.90      0.93       104
           3       0.95      1.00      0.97       168

    accuracy                           0.91       653
   macro avg       0.92      0.91      0.91       653
weighted avg       0.91      0.91      0.91       653

Analysis:

High Specificity in Some Classes: The model is highly specific in recognizing pituitary tumors and generally good at identifying glioma tumors.

Challenges with Meningioma: There seems to be some confusion between meningioma and glioma tumors, which might require further investigation. Feature similarities between these types could be causing the model to struggle in differentiating them accurately.

Potential for Serious Misclassification: The misclassification between tumorous and non-tumorous scans, although low, is a critical error and should be minimized as much as possible.

In [29]:
labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# Define the custom color map
custom_colors = ['#01411C','#4B6F44','#4F7942','#74C365','#D0F0C0']
custom_cmap = matplotlib.colors.ListedColormap(custom_colors)

# Calculate confusion matrix
confusion_matrix = confusion_matrix(Y_test, y_pred)

# Create a display object with the custom color map
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor'])

# Plot the confusion matrix
fig, ax = plt.subplots()
disp.plot(cmap=custom_cmap, ax=ax)

# Set the title and axis labels
fig.text(s='Heatmap of the Confusion Matrix',size=18,fontweight='bold',
             fontname='monospace',color=colors_dark[1],y=0.92,x=0.10,alpha=0.8)

# Rotate x-axis labels
plt.xticks(rotation=45)

# Save the figure
plt.savefig('CM CNN-2.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()
In [30]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5), facecolor='white')

# Plot training and validation accuracy
ax[0].plot(history.history['accuracy'])
ax[0].plot(history.history['val_accuracy'])
ax[0].set_title('Model Accuracy')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Accuracy')
ax[0].legend(['Train', 'Validation'], loc='upper left')

# Plot training and validation loss
ax[1].plot(history.history['loss'])
ax[1].plot(history.history['val_loss'])
ax[1].set_title('Model Loss')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].legend(['Train', 'Validation'], loc='upper right')

# Save the figure
plt.savefig('plot CNN-2.png', dpi=300, bbox_inches='tight')

plt.tight_layout()
plt.show()
In [31]:
from sklearn.metrics import roc_curve, roc_auc_score
import numpy as np

# Compute predicted probabilities for each class
y_probs = model.predict(X_test)

# Ensure that the target labels Y_test are in a 2-dimensional format
if len(Y_test.shape) == 1:
    Y_test = np.eye(len(np.unique(Y_test)))[Y_test.astype(int)]

# Compute the ROC curve and AUC score for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(Y_test.shape[1]):
    fpr[i], tpr[i], _ = roc_curve(Y_test[:, i], y_probs[:, i])
    roc_auc[i] = roc_auc_score(Y_test[:, i], y_probs[:, i])

# Plot the ROC curve for each class
plt.figure()
for i in range(Y_test.shape[1]):
    plt.plot(fpr[i], tpr[i], label=f'Class {i} (AUC = {roc_auc[i]:.2f})')

# Set the title and axis labels
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')

# Save the figure
plt.savefig('ROC CNN-2.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 107ms/step
Wrong prediction example¶
In [32]:
class_labels=['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
plt.figure(figsize=(16,20))

for i in range(16):
    plt.subplot(4,4,i+1)
    plt.imshow(X_test[i])
    actual_label_idx = np.argmax(Y_test[i])  # Assuming Y_test is one-hot encoded
    predicted_label_idx = np.argmax(y_pred[i])  # Assuming y_pred is one-hot encoded
    plt.title(f"Actual label:{class_labels[actual_label_idx]}\nPredicted label:{class_labels[predicted_label_idx]}")
    plt.axis("off")
In [ ]:
 
In [ ]:
 

Step 7 TL with EfficientNetB0¶

7.1 Model 1¶
In [25]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

num_classes = len(labels)

# Define EfficientNet model
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()
Model: "functional_29"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer_4       │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling_4         │ (None, 224, 224,  │          0 │ input_layer_4[0]… │
│ (Rescaling)         │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_2     │ (None, 224, 224,  │          7 │ rescaling_4[0][0] │
│ (Normalization)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling_5         │ (None, 224, 224,  │          0 │ normalization_2[… │
│ (Rescaling)         │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv_pad       │ (None, 225, 225,  │          0 │ rescaling_5[0][0] │
│ (ZeroPadding2D)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv2D)  │ (None, 112, 112,  │        864 │ stem_conv_pad[0]… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_bn             │ (None, 112, 112,  │        128 │ stem_conv[0][0]   │
│ (BatchNormalizatio… │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_activation     │ (None, 112, 112,  │          0 │ stem_bn[0][0]     │
│ (Activation)        │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_dwconv      │ (None, 112, 112,  │        288 │ stem_activation[… │
│ (DepthwiseConv2D)   │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_bn          │ (None, 112, 112,  │        128 │ block1a_dwconv[0… │
│ (BatchNormalizatio… │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_activation  │ (None, 112, 112,  │          0 │ block1a_bn[0][0]  │
│ (Activation)        │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_squeeze  │ (None, 32)        │          0 │ block1a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_reshape  │ (None, 1, 1, 32)  │          0 │ block1a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_reduce   │ (None, 1, 1, 8)   │        264 │ block1a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_expand   │ (None, 1, 1, 32)  │        288 │ block1a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_excite   │ (None, 112, 112,  │          0 │ block1a_activati… │
│ (Multiply)          │ 32)               │            │ block1a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_project_co… │ (None, 112, 112,  │        512 │ block1a_se_excit… │
│ (Conv2D)            │ 16)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_project_bn  │ (None, 112, 112,  │         64 │ block1a_project_… │
│ (BatchNormalizatio… │ 16)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_conv │ (None, 112, 112,  │      1,536 │ block1a_project_… │
│ (Conv2D)            │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_bn   │ (None, 112, 112,  │        384 │ block2a_expand_c… │
│ (BatchNormalizatio… │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_act… │ (None, 112, 112,  │          0 │ block2a_expand_b… │
│ (Activation)        │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_dwconv_pad  │ (None, 113, 113,  │          0 │ block2a_expand_a… │
│ (ZeroPadding2D)     │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_dwconv      │ (None, 56, 56,    │        864 │ block2a_dwconv_p… │
│ (DepthwiseConv2D)   │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_bn          │ (None, 56, 56,    │        384 │ block2a_dwconv[0… │
│ (BatchNormalizatio… │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_activation  │ (None, 56, 56,    │          0 │ block2a_bn[0][0]  │
│ (Activation)        │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_squeeze  │ (None, 96)        │          0 │ block2a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_reshape  │ (None, 1, 1, 96)  │          0 │ block2a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_reduce   │ (None, 1, 1, 4)   │        388 │ block2a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_expand   │ (None, 1, 1, 96)  │        480 │ block2a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_excite   │ (None, 56, 56,    │          0 │ block2a_activati… │
│ (Multiply)          │ 96)               │            │ block2a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_project_co… │ (None, 56, 56,    │      2,304 │ block2a_se_excit… │
│ (Conv2D)            │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_project_bn  │ (None, 56, 56,    │         96 │ block2a_project_… │
│ (BatchNormalizatio… │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_conv │ (None, 56, 56,    │      3,456 │ block2a_project_… │
│ (Conv2D)            │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_bn   │ (None, 56, 56,    │        576 │ block2b_expand_c… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_act… │ (None, 56, 56,    │          0 │ block2b_expand_b… │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_dwconv      │ (None, 56, 56,    │      1,296 │ block2b_expand_a… │
│ (DepthwiseConv2D)   │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_bn          │ (None, 56, 56,    │        576 │ block2b_dwconv[0… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_activation  │ (None, 56, 56,    │          0 │ block2b_bn[0][0]  │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_squeeze  │ (None, 144)       │          0 │ block2b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_reshape  │ (None, 1, 1, 144) │          0 │ block2b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_reduce   │ (None, 1, 1, 6)   │        870 │ block2b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_expand   │ (None, 1, 1, 144) │      1,008 │ block2b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_excite   │ (None, 56, 56,    │          0 │ block2b_activati… │
│ (Multiply)          │ 144)              │            │ block2b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_project_co… │ (None, 56, 56,    │      3,456 │ block2b_se_excit… │
│ (Conv2D)            │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_project_bn  │ (None, 56, 56,    │         96 │ block2b_project_… │
│ (BatchNormalizatio… │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_drop        │ (None, 56, 56,    │          0 │ block2b_project_… │
│ (Dropout)           │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_add (Add)   │ (None, 56, 56,    │          0 │ block2b_drop[0][… │
│                     │ 24)               │            │ block2a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_conv │ (None, 56, 56,    │      3,456 │ block2b_add[0][0] │
│ (Conv2D)            │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_bn   │ (None, 56, 56,    │        576 │ block3a_expand_c… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_act… │ (None, 56, 56,    │          0 │ block3a_expand_b… │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_dwconv_pad  │ (None, 59, 59,    │          0 │ block3a_expand_a… │
│ (ZeroPadding2D)     │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_dwconv      │ (None, 28, 28,    │      3,600 │ block3a_dwconv_p… │
│ (DepthwiseConv2D)   │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_bn          │ (None, 28, 28,    │        576 │ block3a_dwconv[0… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_activation  │ (None, 28, 28,    │          0 │ block3a_bn[0][0]  │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_squeeze  │ (None, 144)       │          0 │ block3a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_reshape  │ (None, 1, 1, 144) │          0 │ block3a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_reduce   │ (None, 1, 1, 6)   │        870 │ block3a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_expand   │ (None, 1, 1, 144) │      1,008 │ block3a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_excite   │ (None, 28, 28,    │          0 │ block3a_activati… │
│ (Multiply)          │ 144)              │            │ block3a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_project_co… │ (None, 28, 28,    │      5,760 │ block3a_se_excit… │
│ (Conv2D)            │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_project_bn  │ (None, 28, 28,    │        160 │ block3a_project_… │
│ (BatchNormalizatio… │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_conv │ (None, 28, 28,    │      9,600 │ block3a_project_… │
│ (Conv2D)            │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_bn   │ (None, 28, 28,    │        960 │ block3b_expand_c… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_act… │ (None, 28, 28,    │          0 │ block3b_expand_b… │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_dwconv      │ (None, 28, 28,    │      6,000 │ block3b_expand_a… │
│ (DepthwiseConv2D)   │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_bn          │ (None, 28, 28,    │        960 │ block3b_dwconv[0… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_activation  │ (None, 28, 28,    │          0 │ block3b_bn[0][0]  │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_squeeze  │ (None, 240)       │          0 │ block3b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_reshape  │ (None, 1, 1, 240) │          0 │ block3b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_reduce   │ (None, 1, 1, 10)  │      2,410 │ block3b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_expand   │ (None, 1, 1, 240) │      2,640 │ block3b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_excite   │ (None, 28, 28,    │          0 │ block3b_activati… │
│ (Multiply)          │ 240)              │            │ block3b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_project_co… │ (None, 28, 28,    │      9,600 │ block3b_se_excit… │
│ (Conv2D)            │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_project_bn  │ (None, 28, 28,    │        160 │ block3b_project_… │
│ (BatchNormalizatio… │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_drop        │ (None, 28, 28,    │          0 │ block3b_project_… │
│ (Dropout)           │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_add (Add)   │ (None, 28, 28,    │          0 │ block3b_drop[0][… │
│                     │ 40)               │            │ block3a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_conv │ (None, 28, 28,    │      9,600 │ block3b_add[0][0] │
│ (Conv2D)            │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_bn   │ (None, 28, 28,    │        960 │ block4a_expand_c… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_act… │ (None, 28, 28,    │          0 │ block4a_expand_b… │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_dwconv_pad  │ (None, 29, 29,    │          0 │ block4a_expand_a… │
│ (ZeroPadding2D)     │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_dwconv      │ (None, 14, 14,    │      2,160 │ block4a_dwconv_p… │
│ (DepthwiseConv2D)   │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_bn          │ (None, 14, 14,    │        960 │ block4a_dwconv[0… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_activation  │ (None, 14, 14,    │          0 │ block4a_bn[0][0]  │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_squeeze  │ (None, 240)       │          0 │ block4a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_reshape  │ (None, 1, 1, 240) │          0 │ block4a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_reduce   │ (None, 1, 1, 10)  │      2,410 │ block4a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_expand   │ (None, 1, 1, 240) │      2,640 │ block4a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_excite   │ (None, 14, 14,    │          0 │ block4a_activati… │
│ (Multiply)          │ 240)              │            │ block4a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_project_co… │ (None, 14, 14,    │     19,200 │ block4a_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_project_bn  │ (None, 14, 14,    │        320 │ block4a_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_conv │ (None, 14, 14,    │     38,400 │ block4a_project_… │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_bn   │ (None, 14, 14,    │      1,920 │ block4b_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_act… │ (None, 14, 14,    │          0 │ block4b_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_dwconv      │ (None, 14, 14,    │      4,320 │ block4b_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_bn          │ (None, 14, 14,    │      1,920 │ block4b_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_activation  │ (None, 14, 14,    │          0 │ block4b_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_squeeze  │ (None, 480)       │          0 │ block4b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_reshape  │ (None, 1, 1, 480) │          0 │ block4b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block4b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_expand   │ (None, 1, 1, 480) │     10,080 │ block4b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_excite   │ (None, 14, 14,    │          0 │ block4b_activati… │
│ (Multiply)          │ 480)              │            │ block4b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_project_co… │ (None, 14, 14,    │     38,400 │ block4b_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_project_bn  │ (None, 14, 14,    │        320 │ block4b_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_drop        │ (None, 14, 14,    │          0 │ block4b_project_… │
│ (Dropout)           │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_add (Add)   │ (None, 14, 14,    │          0 │ block4b_drop[0][… │
│                     │ 80)               │            │ block4a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_conv │ (None, 14, 14,    │     38,400 │ block4b_add[0][0] │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_bn   │ (None, 14, 14,    │      1,920 │ block4c_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_act… │ (None, 14, 14,    │          0 │ block4c_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_dwconv      │ (None, 14, 14,    │      4,320 │ block4c_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_bn          │ (None, 14, 14,    │      1,920 │ block4c_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_activation  │ (None, 14, 14,    │          0 │ block4c_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_squeeze  │ (None, 480)       │          0 │ block4c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_reshape  │ (None, 1, 1, 480) │          0 │ block4c_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block4c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_expand   │ (None, 1, 1, 480) │     10,080 │ block4c_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_excite   │ (None, 14, 14,    │          0 │ block4c_activati… │
│ (Multiply)          │ 480)              │            │ block4c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_project_co… │ (None, 14, 14,    │     38,400 │ block4c_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_project_bn  │ (None, 14, 14,    │        320 │ block4c_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_drop        │ (None, 14, 14,    │          0 │ block4c_project_… │
│ (Dropout)           │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_add (Add)   │ (None, 14, 14,    │          0 │ block4c_drop[0][… │
│                     │ 80)               │            │ block4b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_conv │ (None, 14, 14,    │     38,400 │ block4c_add[0][0] │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_bn   │ (None, 14, 14,    │      1,920 │ block5a_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_act… │ (None, 14, 14,    │          0 │ block5a_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_dwconv      │ (None, 14, 14,    │     12,000 │ block5a_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_bn          │ (None, 14, 14,    │      1,920 │ block5a_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_activation  │ (None, 14, 14,    │          0 │ block5a_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_squeeze  │ (None, 480)       │          0 │ block5a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_reshape  │ (None, 1, 1, 480) │          0 │ block5a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block5a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_expand   │ (None, 1, 1, 480) │     10,080 │ block5a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_excite   │ (None, 14, 14,    │          0 │ block5a_activati… │
│ (Multiply)          │ 480)              │            │ block5a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_project_co… │ (None, 14, 14,    │     53,760 │ block5a_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_project_bn  │ (None, 14, 14,    │        448 │ block5a_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_conv │ (None, 14, 14,    │     75,264 │ block5a_project_… │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_bn   │ (None, 14, 14,    │      2,688 │ block5b_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_act… │ (None, 14, 14,    │          0 │ block5b_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_dwconv      │ (None, 14, 14,    │     16,800 │ block5b_expand_a… │
│ (DepthwiseConv2D)   │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_bn          │ (None, 14, 14,    │      2,688 │ block5b_dwconv[0… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_activation  │ (None, 14, 14,    │          0 │ block5b_bn[0][0]  │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_squeeze  │ (None, 672)       │          0 │ block5b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_reshape  │ (None, 1, 1, 672) │          0 │ block5b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block5b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_expand   │ (None, 1, 1, 672) │     19,488 │ block5b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_excite   │ (None, 14, 14,    │          0 │ block5b_activati… │
│ (Multiply)          │ 672)              │            │ block5b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_project_co… │ (None, 14, 14,    │     75,264 │ block5b_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_project_bn  │ (None, 14, 14,    │        448 │ block5b_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_drop        │ (None, 14, 14,    │          0 │ block5b_project_… │
│ (Dropout)           │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_add (Add)   │ (None, 14, 14,    │          0 │ block5b_drop[0][… │
│                     │ 112)              │            │ block5a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_conv │ (None, 14, 14,    │     75,264 │ block5b_add[0][0] │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_bn   │ (None, 14, 14,    │      2,688 │ block5c_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_act… │ (None, 14, 14,    │          0 │ block5c_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_dwconv      │ (None, 14, 14,    │     16,800 │ block5c_expand_a… │
│ (DepthwiseConv2D)   │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_bn          │ (None, 14, 14,    │      2,688 │ block5c_dwconv[0… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_activation  │ (None, 14, 14,    │          0 │ block5c_bn[0][0]  │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_squeeze  │ (None, 672)       │          0 │ block5c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_reshape  │ (None, 1, 1, 672) │          0 │ block5c_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block5c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_expand   │ (None, 1, 1, 672) │     19,488 │ block5c_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_excite   │ (None, 14, 14,    │          0 │ block5c_activati… │
│ (Multiply)          │ 672)              │            │ block5c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_project_co… │ (None, 14, 14,    │     75,264 │ block5c_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_project_bn  │ (None, 14, 14,    │        448 │ block5c_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_drop        │ (None, 14, 14,    │          0 │ block5c_project_… │
│ (Dropout)           │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_add (Add)   │ (None, 14, 14,    │          0 │ block5c_drop[0][… │
│                     │ 112)              │            │ block5b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_conv │ (None, 14, 14,    │     75,264 │ block5c_add[0][0] │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_bn   │ (None, 14, 14,    │      2,688 │ block6a_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_act… │ (None, 14, 14,    │          0 │ block6a_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_dwconv_pad  │ (None, 17, 17,    │          0 │ block6a_expand_a… │
│ (ZeroPadding2D)     │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_dwconv      │ (None, 7, 7, 672) │     16,800 │ block6a_dwconv_p… │
│ (DepthwiseConv2D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_bn          │ (None, 7, 7, 672) │      2,688 │ block6a_dwconv[0… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_activation  │ (None, 7, 7, 672) │          0 │ block6a_bn[0][0]  │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_squeeze  │ (None, 672)       │          0 │ block6a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_reshape  │ (None, 1, 1, 672) │          0 │ block6a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block6a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_expand   │ (None, 1, 1, 672) │     19,488 │ block6a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_excite   │ (None, 7, 7, 672) │          0 │ block6a_activati… │
│ (Multiply)          │                   │            │ block6a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_project_co… │ (None, 7, 7, 192) │    129,024 │ block6a_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_project_bn  │ (None, 7, 7, 192) │        768 │ block6a_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_conv │ (None, 7, 7,      │    221,184 │ block6a_project_… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_bn   │ (None, 7, 7,      │      4,608 │ block6b_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_act… │ (None, 7, 7,      │          0 │ block6b_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_dwconv      │ (None, 7, 7,      │     28,800 │ block6b_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_bn          │ (None, 7, 7,      │      4,608 │ block6b_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_activation  │ (None, 7, 7,      │          0 │ block6b_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_squeeze  │ (None, 1152)      │          0 │ block6b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_reshape  │ (None, 1, 1,      │          0 │ block6b_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_expand   │ (None, 1, 1,      │     56,448 │ block6b_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_excite   │ (None, 7, 7,      │          0 │ block6b_activati… │
│ (Multiply)          │ 1152)             │            │ block6b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_project_co… │ (None, 7, 7, 192) │    221,184 │ block6b_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_project_bn  │ (None, 7, 7, 192) │        768 │ block6b_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_drop        │ (None, 7, 7, 192) │          0 │ block6b_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_add (Add)   │ (None, 7, 7, 192) │          0 │ block6b_drop[0][… │
│                     │                   │            │ block6a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_conv │ (None, 7, 7,      │    221,184 │ block6b_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_bn   │ (None, 7, 7,      │      4,608 │ block6c_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_act… │ (None, 7, 7,      │          0 │ block6c_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_dwconv      │ (None, 7, 7,      │     28,800 │ block6c_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_bn          │ (None, 7, 7,      │      4,608 │ block6c_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_activation  │ (None, 7, 7,      │          0 │ block6c_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_squeeze  │ (None, 1152)      │          0 │ block6c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_reshape  │ (None, 1, 1,      │          0 │ block6c_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_expand   │ (None, 1, 1,      │     56,448 │ block6c_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_excite   │ (None, 7, 7,      │          0 │ block6c_activati… │
│ (Multiply)          │ 1152)             │            │ block6c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_project_co… │ (None, 7, 7, 192) │    221,184 │ block6c_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_project_bn  │ (None, 7, 7, 192) │        768 │ block6c_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_drop        │ (None, 7, 7, 192) │          0 │ block6c_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_add (Add)   │ (None, 7, 7, 192) │          0 │ block6c_drop[0][… │
│                     │                   │            │ block6b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_conv │ (None, 7, 7,      │    221,184 │ block6c_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_bn   │ (None, 7, 7,      │      4,608 │ block6d_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_act… │ (None, 7, 7,      │          0 │ block6d_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_dwconv      │ (None, 7, 7,      │     28,800 │ block6d_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_bn          │ (None, 7, 7,      │      4,608 │ block6d_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_activation  │ (None, 7, 7,      │          0 │ block6d_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_squeeze  │ (None, 1152)      │          0 │ block6d_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_reshape  │ (None, 1, 1,      │          0 │ block6d_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6d_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_expand   │ (None, 1, 1,      │     56,448 │ block6d_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_excite   │ (None, 7, 7,      │          0 │ block6d_activati… │
│ (Multiply)          │ 1152)             │            │ block6d_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_project_co… │ (None, 7, 7, 192) │    221,184 │ block6d_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_project_bn  │ (None, 7, 7, 192) │        768 │ block6d_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_drop        │ (None, 7, 7, 192) │          0 │ block6d_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_add (Add)   │ (None, 7, 7, 192) │          0 │ block6d_drop[0][… │
│                     │                   │            │ block6c_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_conv │ (None, 7, 7,      │    221,184 │ block6d_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_bn   │ (None, 7, 7,      │      4,608 │ block7a_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_act… │ (None, 7, 7,      │          0 │ block7a_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_dwconv      │ (None, 7, 7,      │     10,368 │ block7a_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_bn          │ (None, 7, 7,      │      4,608 │ block7a_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_activation  │ (None, 7, 7,      │          0 │ block7a_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_squeeze  │ (None, 1152)      │          0 │ block7a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_reshape  │ (None, 1, 1,      │          0 │ block7a_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block7a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_expand   │ (None, 1, 1,      │     56,448 │ block7a_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_excite   │ (None, 7, 7,      │          0 │ block7a_activati… │
│ (Multiply)          │ 1152)             │            │ block7a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_project_co… │ (None, 7, 7, 320) │    368,640 │ block7a_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_project_bn  │ (None, 7, 7, 320) │      1,280 │ block7a_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_conv (Conv2D)   │ (None, 7, 7,      │    409,600 │ block7a_project_… │
│                     │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_bn              │ (None, 7, 7,      │      5,120 │ top_conv[0][0]    │
│ (BatchNormalizatio… │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_activation      │ (None, 7, 7,      │          0 │ top_bn[0][0]      │
│ (Activation)        │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 1280)      │          0 │ top_activation[0… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_7 (Dense)     │ (None, 512)       │    655,872 │ global_average_p… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_8 (Dense)     │ (None, 4)         │      2,052 │ dense_7[0][0]     │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 4,707,495 (17.96 MB)
 Trainable params: 4,665,472 (17.80 MB)
 Non-trainable params: 42,023 (164.16 KB)
In [26]:
# Train the model
history = model.fit(train_generator,
                    epochs=50,
                    validation_data=val_generator,
                    steps_per_epoch=len(X_train) // 64,
                    validation_steps=len(X_valid) // 64)
Epoch 1/50
C:\Users\yanch\anaconda3\Lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py:120: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()
36/36 ━━━━━━━━━━━━━━━━━━━━ 158s 3s/step - accuracy: 0.6149 - loss: 0.8603 - val_accuracy: 0.1016 - val_loss: 8.2727
Epoch 2/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 97s 3s/step - accuracy: 0.8646 - loss: 0.3835 - val_accuracy: 0.1484 - val_loss: 3.3987
Epoch 3/50
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
  self.gen.throw(typ, value, traceback)
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 66ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1797 - val_loss: 3.4343
Epoch 4/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9037 - loss: 0.2730 - val_accuracy: 0.1562 - val_loss: 3.9329
Epoch 5/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 96s 3s/step - accuracy: 0.8994 - loss: 0.2943 - val_accuracy: 0.1967 - val_loss: 5.9635
Epoch 6/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 67ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1328 - val_loss: 6.3310
Epoch 7/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 107s 3s/step - accuracy: 0.8992 - loss: 0.2623 - val_accuracy: 0.1641 - val_loss: 6.7574
Epoch 8/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 99s 3s/step - accuracy: 0.9431 - loss: 0.1577 - val_accuracy: 0.1797 - val_loss: 2.4378
Epoch 9/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 62ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1484 - val_loss: 2.4938
Epoch 10/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 106s 3s/step - accuracy: 0.9173 - loss: 0.2078 - val_accuracy: 0.1148 - val_loss: 7.2145
Epoch 11/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 100s 3s/step - accuracy: 0.9472 - loss: 0.1509 - val_accuracy: 0.1641 - val_loss: 4.5552
Epoch 12/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 63ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1484 - val_loss: 4.7996
Epoch 13/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 108s 3s/step - accuracy: 0.9501 - loss: 0.1690 - val_accuracy: 0.1250 - val_loss: 7.1115
Epoch 14/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 99s 3s/step - accuracy: 0.9436 - loss: 0.1612 - val_accuracy: 0.0938 - val_loss: 2.9126
Epoch 15/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1639 - val_loss: 3.1586
Epoch 16/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 106s 3s/step - accuracy: 0.9576 - loss: 0.1224 - val_accuracy: 0.1562 - val_loss: 5.2999
Epoch 17/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 99s 3s/step - accuracy: 0.9546 - loss: 0.1341 - val_accuracy: 0.2656 - val_loss: 2.9544
Epoch 18/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 63ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.2422 - val_loss: 3.0065
Epoch 19/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 117s 3s/step - accuracy: 0.9539 - loss: 0.1364 - val_accuracy: 0.3594 - val_loss: 2.2156
Epoch 20/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 106s 3s/step - accuracy: 0.9698 - loss: 0.0903 - val_accuracy: 0.3607 - val_loss: 2.2665
Epoch 21/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 3s 95ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.3984 - val_loss: 2.2333
Epoch 22/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 112s 3s/step - accuracy: 0.9586 - loss: 0.1315 - val_accuracy: 0.2734 - val_loss: 3.1006
Epoch 23/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 100s 3s/step - accuracy: 0.9690 - loss: 0.0914 - val_accuracy: 0.1406 - val_loss: 8.3236
Epoch 24/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 64ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.1484 - val_loss: 7.7363
Epoch 25/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 106s 3s/step - accuracy: 0.9597 - loss: 0.1182 - val_accuracy: 0.2131 - val_loss: 4.4671
Epoch 26/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 100s 3s/step - accuracy: 0.9551 - loss: 0.1098 - val_accuracy: 0.2969 - val_loss: 5.1090
Epoch 27/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 62ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.2422 - val_loss: 5.5664
Epoch 28/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 104s 3s/step - accuracy: 0.9691 - loss: 0.0851 - val_accuracy: 0.3438 - val_loss: 5.1524
Epoch 29/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 97s 3s/step - accuracy: 0.9559 - loss: 0.1273 - val_accuracy: 0.6016 - val_loss: 1.6378
Epoch 30/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5246 - val_loss: 1.7923
Epoch 31/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9711 - loss: 0.0962 - val_accuracy: 0.3906 - val_loss: 3.0544
Epoch 32/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 97s 3s/step - accuracy: 0.9670 - loss: 0.0949 - val_accuracy: 0.6875 - val_loss: 1.1604
Epoch 33/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 65ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.7344 - val_loss: 1.1340
Epoch 34/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9729 - loss: 0.0753 - val_accuracy: 0.5469 - val_loss: 2.1677
Epoch 35/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 97s 3s/step - accuracy: 0.9640 - loss: 0.0922 - val_accuracy: 0.8033 - val_loss: 1.1701
Epoch 36/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 3s 83ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.7422 - val_loss: 1.2480
Epoch 37/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 104s 3s/step - accuracy: 0.9721 - loss: 0.0733 - val_accuracy: 0.7188 - val_loss: 1.3276
Epoch 38/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 98s 3s/step - accuracy: 0.9714 - loss: 0.0812 - val_accuracy: 0.7188 - val_loss: 1.5875
Epoch 39/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 64ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.7031 - val_loss: 1.5783
Epoch 40/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 103s 3s/step - accuracy: 0.9769 - loss: 0.0631 - val_accuracy: 0.8361 - val_loss: 0.3836
Epoch 41/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 98s 3s/step - accuracy: 0.9772 - loss: 0.0764 - val_accuracy: 0.7344 - val_loss: 1.1949
Epoch 42/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 65ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.7969 - val_loss: 0.9238
Epoch 43/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9673 - loss: 0.1129 - val_accuracy: 0.8438 - val_loss: 0.4739
Epoch 44/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 96s 3s/step - accuracy: 0.9638 - loss: 0.0925 - val_accuracy: 0.4531 - val_loss: 2.6108
Epoch 45/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 35ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4590 - val_loss: 2.5071
Epoch 46/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9727 - loss: 0.0740 - val_accuracy: 0.2266 - val_loss: 3.5780
Epoch 47/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 96s 3s/step - accuracy: 0.9684 - loss: 0.0857 - val_accuracy: 0.9062 - val_loss: 0.4636
Epoch 48/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 2s 64ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.8828 - val_loss: 0.4826
Epoch 49/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 105s 3s/step - accuracy: 0.9912 - loss: 0.0501 - val_accuracy: 0.8906 - val_loss: 0.5076
Epoch 50/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 96s 3s/step - accuracy: 0.9843 - loss: 0.0448 - val_accuracy: 0.7377 - val_loss: 1.2556
In [45]:
# Evaluate the model
test_generator = valid_datagen.flow(X_test, y_test_new, batch_size=64)
test_loss, test_accuracy = model.evaluate(test_generator, steps=len(X_test) // 64)
print(f"Test Accuracy: {test_accuracy}")
10/10 ━━━━━━━━━━━━━━━━━━━━ 2s 156ms/step - accuracy: 0.1679 - loss: 1.3863
Test Accuracy: 0.16249999403953552

The training and validation loss plots, as well as the training and validation accuracy plots, show that there might be issues with overfitting and instability during training.

In [41]:
# Plotting training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
In [42]:
# Plotting training and validation accuracy
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
7.2 Model 2 adjusted for overfitting¶
In [46]:
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import Dropout

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu', kernel_regularizer=l2(0.01))(x)
x = Dropout(0.5)(x)
predictions = Dense(num_classes, activation='softmax', kernel_regularizer=l2(0.01))(x)

model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_10 (Conv2D)              │ (None, 220, 220, 16)   │         1,216 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_10 (MaxPooling2D) │ (None, 110, 110, 16)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_11 (Conv2D)              │ (None, 108, 108, 32)   │         4,640 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_11 (MaxPooling2D) │ (None, 54, 54, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_12 (Conv2D)              │ (None, 52, 52, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_12 (MaxPooling2D) │ (None, 26, 26, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_13 (Conv2D)              │ (None, 24, 24, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_13 (MaxPooling2D) │ (None, 12, 12, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_14 (Conv2D)              │ (None, 10, 10, 256)    │       295,168 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_14 (MaxPooling2D) │ (None, 5, 5, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_2 (Flatten)             │ (None, 6400)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_9 (Dense)                 │ (None, 512)            │     3,277,312 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_10 (Dense)                │ (None, 4)              │         2,052 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 3,672,740 (14.01 MB)
 Trainable params: 3,672,740 (14.01 MB)
 Non-trainable params: 0 (0.00 B)
In [47]:
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
In [48]:
# Train the model
history = model.fit(train_generator,
                    epochs=50,
                    validation_data=val_generator,
                    steps_per_epoch=len(X_train) // 64,
                    validation_steps=len(X_valid) // 64,
                    callbacks=[reduce_lr, early_stopping])
Epoch 1/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 42s 970ms/step - accuracy: 0.2520 - loss: 1.3695 - val_accuracy: 0.3281 - val_loss: 1.3465 - learning_rate: 1.0000e-04
Epoch 2/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 790ms/step - accuracy: 0.3925 - loss: 1.2717 - val_accuracy: 0.4609 - val_loss: 1.1761 - learning_rate: 1.0000e-04
Epoch 3/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.3828 - val_loss: 1.2153 - learning_rate: 1.0000e-04
Epoch 4/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 38s 914ms/step - accuracy: 0.4603 - loss: 1.1935 - val_accuracy: 0.4609 - val_loss: 1.2367 - learning_rate: 1.0000e-04
Epoch 5/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 766ms/step - accuracy: 0.5027 - loss: 1.1517 - val_accuracy: 0.4754 - val_loss: 1.1391 - learning_rate: 1.0000e-04
Epoch 6/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 27ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4375 - val_loss: 1.1860 - learning_rate: 1.0000e-04
Epoch 7/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 38s 917ms/step - accuracy: 0.5470 - loss: 1.0842 - val_accuracy: 0.5000 - val_loss: 1.1957 - learning_rate: 1.0000e-04
Epoch 8/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 770ms/step - accuracy: 0.5352 - loss: 1.0778 - val_accuracy: 0.5547 - val_loss: 1.1083 - learning_rate: 1.0000e-04
Epoch 9/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4922 - val_loss: 1.2243 - learning_rate: 1.0000e-04
Epoch 10/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 885ms/step - accuracy: 0.5688 - loss: 1.0169 - val_accuracy: 0.5082 - val_loss: 1.0406 - learning_rate: 1.0000e-04
Epoch 11/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 792ms/step - accuracy: 0.5924 - loss: 0.9813 - val_accuracy: 0.5078 - val_loss: 1.1500 - learning_rate: 1.0000e-04
Epoch 12/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 23ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5000 - val_loss: 1.1776 - learning_rate: 1.0000e-04
Epoch 13/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 37s 895ms/step - accuracy: 0.6084 - loss: 0.9614 - val_accuracy: 0.5469 - val_loss: 1.1129 - learning_rate: 1.0000e-04
Epoch 14/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 766ms/step - accuracy: 0.5915 - loss: 0.9620 - val_accuracy: 0.5234 - val_loss: 1.0463 - learning_rate: 1.0000e-04
Epoch 15/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.6557 - val_loss: 0.9987 - learning_rate: 1.0000e-04
Epoch 16/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 37s 914ms/step - accuracy: 0.6377 - loss: 0.8989 - val_accuracy: 0.4844 - val_loss: 1.1103 - learning_rate: 1.0000e-04
Epoch 17/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 772ms/step - accuracy: 0.6393 - loss: 0.9007 - val_accuracy: 0.5156 - val_loss: 1.2403 - learning_rate: 1.0000e-04
Epoch 18/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4531 - val_loss: 1.3733 - learning_rate: 1.0000e-04
Epoch 19/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 883ms/step - accuracy: 0.6310 - loss: 0.8554 - val_accuracy: 0.4766 - val_loss: 1.1770 - learning_rate: 1.0000e-04
Epoch 20/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 779ms/step - accuracy: 0.6548 - loss: 0.8691 - val_accuracy: 0.5410 - val_loss: 0.9768 - learning_rate: 1.0000e-04
Epoch 21/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5000 - val_loss: 1.0790 - learning_rate: 1.0000e-04
Epoch 22/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 887ms/step - accuracy: 0.6322 - loss: 0.8440 - val_accuracy: 0.5625 - val_loss: 1.0311 - learning_rate: 1.0000e-04
Epoch 23/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 787ms/step - accuracy: 0.6509 - loss: 0.8023 - val_accuracy: 0.4375 - val_loss: 1.4593 - learning_rate: 1.0000e-04
Epoch 24/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4062 - val_loss: 1.4555 - learning_rate: 1.0000e-04
Epoch 25/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 35s 868ms/step - accuracy: 0.6188 - loss: 0.8770 - val_accuracy: 0.4754 - val_loss: 1.2434 - learning_rate: 1.0000e-04
Epoch 26/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 29s 795ms/step - accuracy: 0.6503 - loss: 0.7951 - val_accuracy: 0.5469 - val_loss: 0.9752 - learning_rate: 2.0000e-05
Epoch 27/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4844 - val_loss: 1.1067 - learning_rate: 2.0000e-05
Epoch 28/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 37s 901ms/step - accuracy: 0.6912 - loss: 0.7337 - val_accuracy: 0.4766 - val_loss: 1.1368 - learning_rate: 2.0000e-05
Epoch 29/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 768ms/step - accuracy: 0.6807 - loss: 0.7792 - val_accuracy: 0.5156 - val_loss: 0.9521 - learning_rate: 2.0000e-05
Epoch 30/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5082 - val_loss: 1.1625 - learning_rate: 2.0000e-05
Epoch 31/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 37s 909ms/step - accuracy: 0.6708 - loss: 0.7774 - val_accuracy: 0.6641 - val_loss: 0.9106 - learning_rate: 2.0000e-05
Epoch 32/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 772ms/step - accuracy: 0.7000 - loss: 0.7276 - val_accuracy: 0.6094 - val_loss: 0.9920 - learning_rate: 2.0000e-05
Epoch 33/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 24ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.6328 - val_loss: 0.9810 - learning_rate: 2.0000e-05
Epoch 34/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 888ms/step - accuracy: 0.7285 - loss: 0.7102 - val_accuracy: 0.5703 - val_loss: 1.0284 - learning_rate: 2.0000e-05
Epoch 35/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 27s 762ms/step - accuracy: 0.6813 - loss: 0.7785 - val_accuracy: 0.5738 - val_loss: 0.9935 - learning_rate: 2.0000e-05
Epoch 36/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 26ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5938 - val_loss: 1.0313 - learning_rate: 2.0000e-05
Epoch 37/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 890ms/step - accuracy: 0.7046 - loss: 0.7486 - val_accuracy: 0.5859 - val_loss: 0.9561 - learning_rate: 1.0000e-05
Epoch 38/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 773ms/step - accuracy: 0.6859 - loss: 0.7417 - val_accuracy: 0.4922 - val_loss: 1.3085 - learning_rate: 1.0000e-05
Epoch 39/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 1s 24ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5547 - val_loss: 1.0669 - learning_rate: 1.0000e-05
Epoch 40/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 36s 887ms/step - accuracy: 0.6863 - loss: 0.7276 - val_accuracy: 0.4918 - val_loss: 1.1799 - learning_rate: 1.0000e-05
Epoch 41/50
36/36 ━━━━━━━━━━━━━━━━━━━━ 28s 787ms/step - accuracy: 0.7090 - loss: 0.7319 - val_accuracy: 0.5547 - val_loss: 0.9836 - learning_rate: 1.0000e-05
In [50]:
# Evaluate the model
test_generator = valid_datagen.flow(X_test, y_test_new, batch_size=64)
test_loss, test_accuracy = model.evaluate(test_generator, steps=len(X_test) // 64)
print(f"Test Accuracy: {test_accuracy}")
10/10 ━━━━━━━━━━━━━━━━━━━━ 2s 161ms/step - accuracy: 0.2872 - loss: 2.4204
Test Accuracy: 0.3031249940395355

The training and validation loss and accuracy plots indicate that the model is facing significant instability during training.

In [51]:
# Plotting training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
In [52]:
# Plotting training and validation accuracy
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Step 8 Conclusion¶

CNN Architecture 2 Run 3 has the highest accuracy

In [22]:
import pandas as pd

# Define data for Architecture 1
data2 = {
    'Steps per Epoch': [5, 50, 100, 100, 37],
    'Epochs': [10, 50, 50, 35, 100],
    'Accuracy': [0.245, 0.9173, 0.9449, 0.928, 0.9096],
    'Notes': ['', '', '', '', 'early stopped at epoch = 46']
}

# Define data for Architecture 2
data1 = {
    'Steps per Epoch': [100],
    'Epochs': [35],
    'Accuracy': [0.925],
    'Notes': ['']
}

# Create DataFrames
df1 = pd.DataFrame(data1, index=[f"Architecture 1 - Run {i+1}" for i in range(len(data1['Steps per Epoch']))])
df2 = pd.DataFrame(data2, index=[f"Architecture 2 (L2 Regularization) - Run {i+1}" for i in range(len(data2['Steps per Epoch']))])

# Concatenate both DataFrames
df = pd.concat([df1, df2])

df
Out[22]:
Steps per Epoch Epochs Accuracy Notes
Architecture 1 - Run 1 100 35 0.9250
Architecture 2 (L2 Regularization) - Run 1 5 10 0.2450
Architecture 2 (L2 Regularization) - Run 2 50 50 0.9173
Architecture 2 (L2 Regularization) - Run 3 100 50 0.9449
Architecture 2 (L2 Regularization) - Run 4 100 35 0.9280
Architecture 2 (L2 Regularization) - Run 5 37 100 0.9096 early stopped at epoch = 46

Model Performance Analysis

Training and Validation Loss: The plot shows that while the training loss has consistently decreased and flattened (indicating good learning), the validation loss has some fluctuations but generally follows the training loss closely without diverging too much. This suggests that the model is not overfitting significantly.

  1. Accuracy Metrics:

Validation Accuracy: Peaked at approximately 95.04% during training, which is quite high. Test Accuracy: Even higher at 94.49%. This consistency between validation and test accuracy is a good sign of the model's ability to generalize well.

  1. Classification Report:

Precision and Recall: Very high across all classes, with Class 3 achieving perfect recall (1.00). This indicates that the model is very effective in identifying true positives for Class 3 without any false negatives.

F1-Score: Also high across all classes, suggesting a good balance between precision and recall. The weighted averages for accuracy, precision, recall, and F1-score are all above 0.94, which is excellent.

Observations

  1. Model Stability: The model demonstrates stable performance across metrics, which is indicative of robust learning capabilities.

  2. Loss Fluctuations: The fluctuations in validation loss could be indicative of potential minor overfitting or could simply be a result of the model navigating through complex loss landscapes. However, as they do not diverge significantly, this is not a major concern currently.

Step 9 Future Work¶

Address Class Imbalance:¶

Augmentation for Underrepresented Classes: Increase the number of augmented images for the underrepresented class (no_tumor) to balance the dataset.

Class Weights: Utilize class weights in the model training process to give more importance to underrepresented classes during the loss calculation.

Oversampling/Undersampling: Consider oversampling the minority class or undersampling the majority classes.

Enhance Data Augmentation:¶

The current augmentation strategy is robust, but we can experiment with less aggressive transformations for brain images, where orientation and structure are important. For example, a high rotation range might not be appropriate as brain tumors and their structures could be highly orientation-specific.

Modify the Network Architecture:¶

Depth and Complexity: As you're dealing with complex medical images, consider gradually increasing the complexity of the CNN. Incorporate deeper layers or additional convolutional blocks to capture more complex features.

Advanced Architectures: Explore more sophisticated architectures like ResNet, Inception, or DenseNet, which might be more effective for medical image analysis due to their deeper and more complex structures.

Continuous Monitoring and Feedback Loop:¶

Once deployed, continuously monitor the model’s performance and establish a feedback loop with medical professionals to collect insights and further improve the model.

Clinical Validation:¶

Before full deployment, ensure that the model undergoes thorough clinical validation to meet regulatory standards and to confirm that it performs well across different demographics and equipment variations.

In [ ]: